using Microsoft.ML;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;
using System;
using System.Collections.Generic;
using System.Linq;

namespace AutoMLExample
{
    /// <summary>
    /// Console entry point that runs one AutoML binary-classification experiment per
    /// <see cref="BinaryClassificationMetric"/> value, records the best run of each
    /// experiment, and saves the overall best model to disk.
    /// </summary>
    public static class Program
    {
        const string trainDataFilename = "BiopsyTrainData.csv";
        const string modelFilename = "Model.bin";
        const string labelColumnName = "class";
        const uint trainingTimeInSeconds = 10;

        // One entry per optimizing metric whose experiment completed successfully.
        static List<TrainingResult> trainingResults = new List<TrainingResult>();

        /// <summary>
        /// Loads the training data, runs an AutoML experiment for every optimizing metric,
        /// ranks the best runs by negative precision (then positive precision), saves the
        /// top-ranked model to <c>Model.bin</c> and prints its metrics to the console.
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        public static void Main(string[] args)
        {
            MLContext mlContext = new MLContext();

            // The file is read with ';' as the column separator and the third argument set to true
            // (header row present).
            IDataView trainDataView = mlContext.Data.LoadFromTextFile<BiopsyData>(trainDataFilename, ';', true);

            var experimentSettings = new BinaryExperimentSettings();
            experimentSettings.MaxExperimentTimeInSeconds = trainingTimeInSeconds;

            var metricValues = Enum.GetValues(typeof(BinaryClassificationMetric));

            // Loop through all metrics, running one experiment optimized for each.
            foreach (BinaryClassificationMetric metricValue in metricValues)
            {
                experimentSettings.OptimizingMetric = metricValue;

                var experiment = mlContext.Auto().CreateBinaryClassificationExperiment(experimentSettings);
                ExperimentResult<BinaryClassificationMetrics> experimentResult;

                try
                {
                    // Run training
                    experimentResult = experiment.Execute(trainDataView, labelColumnName);
                }
                // Workaround for bug in Execute() method: skip this metric, but leave a trace
                // on stderr so a skipped experiment does not go unnoticed.
                catch (ArgumentOutOfRangeException ex)
                {
                    Console.Error.WriteLine($"Skipping metric {metricValue}: Execute() failed with {ex.GetType().Name}: {ex.Message}");
                    continue;
                }

                var bestRun = experimentResult.BestRun;

                // NOTE(review): assumes BestRun / ValidationMetrics may be null when no run
                // completed within the time budget - confirm against the AutoML version in use.
                if (bestRun == null || bestRun.ValidationMetrics == null)
                {
                    Console.Error.WriteLine($"Skipping metric {metricValue}: experiment produced no usable run.");
                    continue;
                }

                var metrics = bestRun.ValidationMetrics;

                trainingResults.Add(new TrainingResult
                {
                    ExperimentTimeInSeconds = trainingTimeInSeconds,
                    Model = bestRun.Model,
                    Metric = metricValue,
                    Trainer = bestRun.TrainerName,
                    Accuracy = metrics.Accuracy,
                    AreaUnderPrecisionRecallCurve = metrics.AreaUnderPrecisionRecallCurve,
                    AreaUnderRocCurve = metrics.AreaUnderRocCurve,
                    F1Score = metrics.F1Score,
                    PositiveRecall = metrics.PositiveRecall,
                    NegativeRecall = metrics.NegativeRecall,
                    PositivePrecision = metrics.PositivePrecision,
                    NegativePrecision = metrics.NegativePrecision,
                    Matrix = metrics.ConfusionMatrix
                });
            }

            // Without this guard, indexing the empty list below would crash with an
            // unrelated-looking ArgumentOutOfRangeException.
            if (trainingResults.Count == 0)
            {
                Console.Error.WriteLine("No experiment produced a model; nothing was saved.");
                Environment.ExitCode = 1;
                return;
            }

            // Rank by negative precision first, using positive precision as the tie-breaker.
            trainingResults = trainingResults
                .OrderByDescending(result => result.NegativePrecision)
                .ThenByDescending(result => result.PositivePrecision)
                .ToList();

            // Save best model to disk
            var bestTrainingResults = trainingResults[0];
            mlContext.Model.Save(bestTrainingResults.Model, trainDataView.Schema, modelFilename);

            PrintResult(bestTrainingResults);
        }

        /// <summary>
        /// Prints the metrics and the formatted confusion table of a training result to the console.
        /// </summary>
        /// <param name="result">The training result to print.</param>
        private static void PrintResult(TrainingResult result)
        {
            Console.WriteLine($"Training Time: {result.ExperimentTimeInSeconds}");
            Console.WriteLine($"Metric: {result.Metric}");
            Console.WriteLine($"Trainer: {result.Trainer}");
            Console.WriteLine($"Accuracy: {result.Accuracy:0.###}");
            Console.WriteLine($"AreaUnderPrecisionRecallCurve: {result.AreaUnderPrecisionRecallCurve:0.###}");
            Console.WriteLine($"AreaUnderRocCurve: {result.AreaUnderRocCurve:0.###}");
            Console.WriteLine($"F1Score: {result.F1Score:0.###}");
            Console.WriteLine($"PositiveRecall: {result.PositiveRecall:0.###}");
            Console.WriteLine($"NegativeRecall: {result.NegativeRecall:0.###}");
            Console.WriteLine($"Positive Precision: {result.PositivePrecision:0.###}");
            Console.WriteLine($"Negative Precision: {result.NegativePrecision:0.###}");
            Console.WriteLine();
            Console.WriteLine(result.Matrix.GetFormattedConfusionTable());
        }
    }
}
